import logging
import re

from collections import defaultdict
from lab.reports import Table, CellFormatter
from downward.reports import PlanningReport

def toString(f):
    """Return ``f`` rendered with exactly one digit after the decimal point."""
    # format(f, spec) is what "{0:spec}".format(f) delegates to.
    return format(f, '.1f')

def valid(attribute):
    """Return False for the missing-value markers None and -1, else True."""
    if attribute is None:
        return False
    return attribute != -1

class ResultTables(PlanningReport):
    """
    Compare A* (``astar_<h>``) with bidirectional NBS (``nbs_<h>``) per domain.

    Algorithm names must have the form ``astar_<h>`` or ``nbs_<h>`` where
    ``<h>`` is one of ``blind``, ``max``, ``hm`` or ``lmcut``. If the
    experiment contains other algorithms, use ``filter_algorithm`` to select
    the ones to compare (typically an ``astar_<h>``/``nbs_<h>`` pair).

    >>> from downward.experiment import FastDownwardExperiment
    >>> exp = FastDownwardExperiment()
    >>> exp.add_report(ResultTables(
    ...     attributes=["expansions", "search_time"],
    ...     filter_algorithm=["astar_lmcut","nbs_lmcut"]))

    Tables returned by get_markup() (heuristic letters: a=blind, b=max,
    c=hm, d=lmcut):
        1. Rows ``<domain>``. Column ``<dir>_<x>`` = # problems solved by
           algorithm ``<dir>_<h>`` (bold when it solved strictly more than
           its counterpart); column ``X: nbs_<x>`` = # problems where nbs
           expanded fewer states than astar.
        2. Rows ``<domain>:<x>``: a = # solved by astar, b = # solved by nbs,
           c = # problems where nbs expanded fewer states than astar,
           d = # problems where nbs jump_expanded < 2 * astar jump_expanded,
           e = # nbs runs whose meeting point has b_g >= f_g.
           Rows ``<domain>``: f_a/f_b/f_c = min/median/max of
           b_initial_goals over nbs runs whose name contains 'max'.

    Output legend:
        'Actions:' --> minimum-(median)-maximum
        'Ex: Bi < Uni' --> # of problems where nbs expands fewer states.
        'Ex_jump: Bi < 2* Uni' --> same as above, with expansions before last jump.
        'Ex: algo:' --> min-(med)-max
        'h: B > F' --> # of problems h-value is bigger in backward direction at mp.
        'Initial Goals:' --> min-(med)-max
        'MP: B <-- F' --> # of times the mp is closer to the goal than start.
        'Ratio_ex_jump:' --> min-(med)-max --> before_jump_expanded / expanded
        'S: algo:' --> # of solved problems with corresponding algo

    """

    # Heuristic suffixes understood by this report, and the short letter
    # used for each of them in column/row keys.
    _HEURISTICS = ('blind', 'max', 'hm', 'lmcut')
    _LETTER = {'blind': 'a', 'max': 'b', 'hm': 'c', 'lmcut': 'd'}

    def __init__(self, **kwargs):
        PlanningReport.__init__(self, **kwargs)

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def _translate_mp(self, mp):
        """Turn a meeting-point description into a pair of 0/1 flags.

        ``mp`` is expected to be a string of ``", "``-separated parts such as
        ``"f_g: 3, b_g: 4, f_h: 1, b_h: 2"``. Returns ``(h, s)`` where ``h``
        is 1 iff ``b_h >= f_h`` and ``s`` is 1 iff ``b_g >= f_g``.
        ``None``, any string containing ``'-'`` (the missing-value
        placeholder), or a description lacking one of the four fields
        yields ``(0, 0)``.
        """
        if mp is None or '-' in mp:
            return (0, 0)
        fields = {}
        for part in mp.split(', '):
            for key in ('f_g', 'b_g', 'f_h', 'b_h'):
                if key + ': ' in part:
                    # Every key is 3 chars followed by ': ', so the number
                    # starts at index 5.
                    fields[key] = int(part[5:])
        if len(fields) < 4:
            # Incomplete description: treat like a missing meeting point
            # instead of failing on an unbound field.
            return (0, 0)
        h = 1 if fields['b_h'] >= fields['f_h'] else 0
        s = 1 if fields['b_g'] >= fields['f_g'] else 0
        return (h, s)

    @staticmethod
    def _counterpart(algo):
        """Return the algorithm with the same heuristic in the other direction."""
        if algo.startswith('nbs'):
            return re.sub('nbs', 'astar', algo)
        return re.sub('astar', 'nbs', algo)

    @staticmethod
    def _count_solved(solved, domain, run):
        """Add ``run`` to ``solved[(domain, algorithm)]`` (+1 on success, else +0).

        The key is always created, so unsolved algorithms still appear with 0.
        """
        key = (domain, run.get('algorithm'))
        solved[key] = solved.get(key, 0) + (1 if run.get('error') == 'success' else 0)

    @staticmethod
    def _is_smaller(values, nbs_algo, astar_algo, factor=1):
        """True iff both algorithms have a valid entry in ``values`` and the
        nbs value is strictly smaller than ``factor`` times the astar value.

        A missing astar value means "not better". The explicit check avoids
        relying on Python 2's ``x < None`` ordering, which raises in Python 3.
        """
        nbs_value = values.get(nbs_algo)
        astar_value = values.get(astar_algo)
        if not (valid(nbs_value) and valid(astar_value)):
            return False
        return nbs_value < factor * astar_value

    # ------------------------------------------------------------------
    # Tables
    # ------------------------------------------------------------------

    def _get_table(self):
        """Build table 1: per-domain solved counts and nbs-expands-fewer counts.

        Rows are domains. Column ``<dir>_<x>`` holds the number of problems
        solved by ``<dir>_<h>``. It is bold when that number is strictly larger
        than the counterpart's (an absent counterpart counts as 0).
        Column ``X: nbs_<x>`` holds the number of problems where nbs expanded
        fewer states than astar.
        """
        table = Table(title="Domain", colored=True)
        solved = {}
        better = {}
        formatter = CellFormatter(bold=True)

        for (domain, problem), runs in sorted(self.problem_runs.items()):
            expanded = {}
            for run in runs:
                self._count_solved(solved, domain, run)
                if valid(run.get('expanded')):
                    expanded[run.get('algorithm')] = run.get('expanded')
            for h in self._HEURISTICS:
                nbs_algo = 'nbs_' + h
                key = (domain, nbs_algo)
                # Always create the key so every domain gets an entry, even 0.
                smaller = self._is_smaller(expanded, nbs_algo, 'astar_' + h)
                better[key] = better.get(key, 0) + (1 if smaller else 0)

        column = {prefix + '_' + h: prefix + '_' + letter
                  for prefix in ('nbs', 'astar')
                  for h, letter in self._LETTER.items()}
        for (domain, algo), value in solved.items():
            # The counterpart may have been filtered out; treat it as 0 solved.
            if value > solved.get((domain, self._counterpart(algo)), 0):
                table.cell_formatters[domain][column[algo]] = formatter
            table.add_cell(domain, column[algo], str(value))
        for (domain, nbs_algo), value in better.items():
            table.add_cell(domain, 'X: ' + column[nbs_algo], str(value))
        return table

    def _get_table2(self):
        """Build table 2: one row per (domain, heuristic) plus initial-goal stats.

        Row ``<domain>:<x>`` has these columns:
        a/b = # solved by astar/nbs,
        c = # problems where nbs expanded fewer states than astar,
        d = # problems where nbs jump_expanded < 2 * astar jump_expanded,
        e = # nbs runs whose meeting point has b_g >= f_g.
        Row ``<domain>`` has f_a/f_b/f_c = min/median/max of b_initial_goals
        over nbs runs whose algorithm name contains 'max'.
        """
        table = Table(title="Domain", colored=True)
        initial_goals = defaultdict(list)
        solved = {}
        better = {}
        better_jump = {}
        mp_count = {}

        for (domain, problem), runs in sorted(self.problem_runs.items()):
            expanded = {}
            jump_expanded = {}
            for run in runs:
                algo = run.get('algorithm')
                self._count_solved(solved, domain, run)
                if valid(run.get('expanded')):
                    expanded[algo] = run.get('expanded')
                if valid(run.get('jump_expanded', -1)):
                    jump_expanded[algo] = run.get('jump_expanded')
                if algo.startswith('nbs'):
                    _, s = self._translate_mp(run.get('meeting_point', '-'))
                    mp_count[domain, algo] = mp_count.get((domain, algo), 0) + s
                    if valid(run.get('b_initial_goals')) and 'max' in algo:
                        initial_goals[domain].append(run.get('b_initial_goals'))
            for h in self._HEURISTICS:
                nbs_algo = 'nbs_' + h
                astar_algo = 'astar_' + h
                key = (domain, nbs_algo)
                if self._is_smaller(expanded, nbs_algo, astar_algo):
                    better[key] = better.get(key, 0) + 1
                if self._is_smaller(jump_expanded, nbs_algo, astar_algo, factor=2):
                    better_jump[key] = better_jump.get(key, 0) + 1

        for (domain, algo), value in solved.items():
            h = re.sub('astar_', '', re.sub('nbs_', '', algo))
            nbs_algo = 'nbs_' + h
            row = domain + ':' + self._LETTER[h]
            is_nbs = algo.startswith('nbs')
            table.add_cell(row, 'b' if is_nbs else 'a', str(value))
            # The astar and nbs entries of one heuristic share this row, and
            # dict order is arbitrary. So c/d are always read via the nbs key.
            # Otherwise the astar entry (which has no counts) could overwrite
            # the nbs counts with 0.
            table.add_cell(row, 'c', better.get((domain, nbs_algo), 0))
            table.add_cell(row, 'd', better_jump.get((domain, nbs_algo), 0))
            if is_nbs:
                table.add_cell(row, 'e', mp_count.get((domain, algo), 0))
        for domain, values in initial_goals.items():
            values.sort()
            table.add_cell(domain, 'f_a', str(values[0]))
            # '//' keeps the (upper) median index an int on Python 2 and 3.
            table.add_cell(domain, 'f_b', str(values[len(values) // 2]))
            table.add_cell(domain, 'f_c', str(values[-1]))
        return table

    # ------------------------------------------------------------------
    # Console output
    # ------------------------------------------------------------------

    def _print_attribute(self, algo, attribute):
        """Print one LaTeX table row per domain summarising ``attribute``.

        Each row is ``domain & min & avg & max & n`` followed by a LaTeX
        line break. The statistics are over the non-negative values of
        ``attribute`` in the runs of ``algo``. Missing runs, None, -1 and
        other negative values are skipped. A domain without any such value
        prints only its name.
        """
        print('Printing %s per domain.' % attribute)
        for domain in sorted(self.domains.keys()):
            values = []
            for problem in self.domains[domain]:
                run = self.runs.get((domain, problem, algo), {})
                value = run.get(attribute, -1)
                if valid(value) and value >= 0:
                    values.append(value)
            cells = [domain]
            # Filtering happens above, so an all-negative domain cannot reach
            # the min/avg/max code with an empty list.
            if values:
                values.sort()
                # float() avoids Python 2 integer division truncating the mean.
                avg = float(sum(values)) / len(values)
                cells.extend([toString(values[0]), toString(avg),
                              toString(values[-1]), str(len(values))])
            print(' & '.join(cells) + '\\\\')

    def _print_stuff(self):
        """Print the per-domain console summaries used by this report."""
        # Alternative summary: action ratios of nbs_lmcut, via
        # self._print_attribute('nbs_lmcut', 'ratio_fb_actions')
        # Expansion ratio before the last jump:
        self._print_attribute('nbs_max', 'ratio_jump_expanded')

    def get_markup(self):
        """Print the console summaries, then return both tables as markup."""
        self._print_stuff()
        tables = [self._get_table(), self._get_table2()]
        return '\n'.join(str(table) for table in tables)


# Report data available via ``self`` (inherited from PlanningReport):
# domains               : domain -> problems
# problems              : set of (domain, problem)
# problem_runs          : (domain, problem) -> runs
# domain_algorithm_runs : (domain, algorithm) -> runs
# runs                  : (domain, problem, algo) -> run
# attributes
# algorithms            : set of algorithm names
# algorithm_info
